
import pandas as pd
import numpy as np
import jieba
import re
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix,classification_report,roc_auc_score
# from gensim.models.word2vec import Word2Vec
from tensorflow import keras
from nlu_model.ptm.word2vector import Word2vector
from nlu_model.util.pkl_impl import save_pkl, load_pkl
from nlu_model.sim.model.sim_model import SimModel


def load_data(path):
    """Read tab-separated sentence pairs and their labels from ``path``.

    Every non-blank line is stripped and split on tabs. Fields 1 and 2
    (0-based) are taken as the two sentences of a pair and field 3 as the
    integer label; field 0 is ignored.

    Args:
        path: Path of the UTF-8 encoded, tab-separated text file.

    Returns:
        A tuple ``(sentence, label)`` where ``sentence`` is a list of
        ``[sentence_a, sentence_b]`` lists and ``label`` is the list of
        integer labels in the same order.

    Raises:
        IndexError: If a non-blank line has fewer than four fields.
        ValueError: If the label field cannot be parsed as an integer.
    """
    sentence = []
    label = []
    # The data is Chinese text, so decode explicitly as UTF-8 instead of
    # depending on the platform's default encoding.
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) carry no sample.
                continue
            ll = line.split("\t")
            sentence.append([ll[1], ll[2]])
            label.append(int(ll[3]))
    return sentence, label

def single2double(sentence):
    """Split a list of sentence pairs into two parallel lists.

    Args:
        sentence: Sequence of pairs; only elements 0 and 1 of each pair
            are read.

    Returns:
        A tuple ``(sentence_1, sentence_2)`` holding, in input order, the
        first elements and the second elements of the pairs.
    """
    firsts = [item[0] for item in sentence]
    seconds = [item[1] for item in sentence]
    return firsts, seconds

# Read the training sentence pairs and their labels from the ATEC file.
sentence, label = load_data("./data/sim/atec_sim/atec_nlp_sim_train.csv")

# Hold out 20% of the pairs as a test split; the fixed random_state keeps
# the split the same from run to run.
x_train, x_test, y_train, y_test = train_test_split(sentence, label, test_size=0.2,random_state=666)
# Debug variant: the same split restricted to the first 100 pairs.
# x_train, x_test, y_train, y_test = train_test_split(sentence[:100], label[:100], test_size=0.2,random_state=666)

def double_plus(x,y):
    """Oversample by emitting every sample whose label equals 1 twice.

    Args:
        x: Sequence of samples.
        y: Sequence of labels, indexed in parallel with ``x``.

    Returns:
        A tuple ``(x_new, y_new)`` of new lists in the original order, in
        which each sample labelled 1 (and its label) appears twice in a
        row and every other sample appears once.
    """
    x_new, y_new = [], []
    for position, sample in enumerate(x):
        target = y[position]
        # Label-1 samples are written twice, everything else once.
        copies = 2 if target == 1 else 1
        x_new.extend([sample] * copies)
        y_new.extend([target] * copies)
    return x_new, y_new

# Duplicate every label-1 pair in the training split (see double_plus);
# the test split is left untouched.
x_train, y_train = double_plus(x_train, y_train)

# Separate each pair list into a list of first sentences and a list of
# second sentences.
x_train_1, x_train_2 = single2double(x_train)
x_test_1, x_test_2 = single2double(x_test)

# Embedding wrapper, populated from three pickle files below.
word2vector = Word2vector()

# NOTE(review): judging by the file names these are the word-to-index map,
# the embedding matrix and the model config -- confirm against
# Word2vector.load.
word2vector.load("./data/ptm/shopping_reviews/w2v_word2idx2020100601.pkl",
                 "./data/ptm/shopping_reviews/w2v_model_metric_2020100601.pkl", 
                 "./data/ptm/shopping_reviews/w2v_model_conf_2020100601.pkl")

# Model configuration handed to SimModel.
# NOTE(review): MAX_LEN / w2c_len presumably are the maximum sequence
# length and the embedding width -- confirm against SimModel.
model_conf = {"MAX_LEN": 50,
              "w2c_len": 300, 
              "emb_model": word2vector}
# Training configuration handed to SimModel.
train_conf = {"batch_size": 64,
              "epochs": 5, 
              "verbose": 1}
# Build the similarity model using the "cdssm_base" architecture key.
sim_model = SimModel("cdssm_base", model_conf, train_conf)

# Run the model's own preprocessing over each of the four sentence lists.
x_train_1 = sim_model.preprocess(x_train_1)
x_train_2 = sim_model.preprocess(x_train_2)
x_test_1 = sim_model.preprocess(x_test_1)
x_test_2 = sim_model.preprocess(x_test_2)

# Train on the oversampled training pairs. The test pairs and labels are
# passed as the third and fourth arguments.
# NOTE(review): presumably used as validation data -- confirm in SimModel.fit.
sim_model.fit([x_train_1, x_train_2], y_train, [x_test_1, x_test_2], y_test)
# Print whatever SimModel.evaluate returns for the held-out split.
print(sim_model.evaluate([x_test_1, x_test_2], y_test))
# Persist the trained model under this path.
sim_model.save("./data/sim/atec_sim/model_20201016")

# Quick manual check of the trained model on hand-written inputs.
# print(cls_model.predict([sentence]))
# The pred_vec result is not printed; `a` is overwritten two lines below.
a = sim_model.pred_vec(["花呗怎么还"])
# print(a)
# Score one sentence pair and print the result.
a = sim_model.predict(["花呗怎么还","花呗应该怎么还"])
print(a)

# Score every pair in the separate test file one at a time and collect
# what the metrics and the result dump need.
ytest = []  # gold labels read from field 3 of each line
ypred = []  # hard 0/1 decisions obtained by thresholding the score at 0.5
yprod = []  # raw values returned by sim_model.predict
data = []   # the [sentence_a, sentence_b] pairs, in file order
# The file holds Chinese text, so decode explicitly as UTF-8 instead of
# depending on the platform's default encoding.
with open("./data/sim/atec_sim/atec_nlp_sim_test_0.2.csv", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:
            # Blank lines (e.g. a trailing newline) carry no sample.
            continue
        ll = line.split("\t")
        pred = sim_model.predict([ll[1], ll[2]])
        ypred.append(1 if pred > 0.5 else 0)
        ytest.append(int(ll[3]))
        yprod.append(pred)
        data.append([ll[1], ll[2]])
# AUC uses the raw scores; the confusion matrix and the classification
# report use the thresholded decisions.
print("auc")
print(roc_auc_score(ytest, yprod))
print("confusion_matrix")
print(confusion_matrix(ytest, ypred))
print("classification_report")
print(classification_report(ytest, ypred))

# Dump one tab-separated row per scored test pair:
# sentence_a, sentence_b, gold label, thresholded prediction, raw score.
# Iterate over the lists built from the test file themselves; the row count
# must not be taken from y_test, which belongs to the train_test_split
# hold-out and generally has a different length.
with open("./data/sim/atec_sim/atec_nlp_sim_test_result.csv", "w", encoding="utf-8") as f:
    for pair, gold, decision, score in zip(data, ytest, ypred, yprod):
        f.write("{}\t{}\t{}\t{}\t{}\n".format(pair[0], pair[1], gold, decision, score))




# sim_model = SimModel("lstm_dssm_base", model_conf, train_conf)

# sim_model.fit([x_train_1, x_train_2], y_train, [x_test_1, x_test_2], y_test)
# print(sim_model.evaluate([x_test_1, x_test_2], y_test))
# sim_model.save("./data/sim/atec_sim/model_20201016")

# # print(cls_model.predict([sentence]))
# a = sim_model.pred_vec(["花呗怎么还"])
# # print(a)
# a = sim_model.predict(["花呗怎么还","花呗应该怎么还"])
# print(a)

# sim_model = SimModel("dssm_base", model_conf, train_conf)

# sim_model.fit([x_train_1, x_train_2], y_train, [x_test_1, x_test_2], y_test)
# print(sim_model.evaluate([x_test_1, x_test_2], y_test))
# sim_model.save("./data/sim/atec_sim/model_20201016")

# # print(cls_model.predict([sentence]))
# a = sim_model.pred_vec(["花呗怎么还"])
# # print(a)
# a = sim_model.predict(["花呗怎么还","花呗应该怎么还"])
# print(a)
